import glfw
from OpenGL.GL import *
from OpenGL.GL.shaders import compileProgram, compileShader

# -----------------------------
# Compute Shader (simple test kernel)
# -----------------------------
# GLSL 4.30 compute kernel: each invocation stores one float derived from
# (its global x index + start_index) into the storage buffer bound at
# binding point 0. Invocations whose index lies beyond the end of the bound
# buffer return without writing.
compute_shader_source = """
#version 430
layout(local_size_x = 256) in;

layout(std430, binding = 0) buffer Data {
    float values[];
};

uniform uint start_index;

void main() {
    uint idx = gl_GlobalInvocationID.x;

    // Dispatches are rounded up to whole work groups, so trailing
    // invocations can index past the end of the bound buffer; skip them.
    if (idx >= uint(values.length())) {
        return;
    }

    uint gid = idx + start_index;
    values[idx] = float(gid % 1024u) * 0.001;
}
"""

# -----------------------------
# HDGL Executor
# -----------------------------
class HDGLExecutor:
    """Headless OpenGL compute runner.

    Creates a hidden 1x1 GLFW window to obtain an OpenGL context, compiles
    ``compute_shader_source`` into a compute program and streams batches of
    work through it via :meth:`process_virtual_lattice`.

    The executor can be used as a context manager; :meth:`close` releases
    the program, the window and GLFW itself.
    """

    # Work-group width; must match ``local_size_x`` in the compute shader.
    _LOCAL_SIZE_X = 256

    def __init__(self):
        """Initialise GLFW, create the hidden context and compile the shader.

        Raises:
            RuntimeError: if GLFW cannot be initialised or the window (and
                therefore the OpenGL context) cannot be created.
        """
        if not glfw.init():
            raise RuntimeError("GLFW init failed")
        self._glfw_ready = True
        self.window = None
        self.shader = None

        try:
            glfw.window_hint(glfw.VISIBLE, glfw.FALSE)
            self.window = glfw.create_window(1, 1, "hidden", None, None)
            # create_window returns a null handle when no context could be
            # created; fail here instead of in the first GL call.
            if not self.window:
                self.window = None
                raise RuntimeError("GLFW window/context creation failed")
            glfw.make_context_current(self.window)

            # Compile compute shader
            self.shader = compileProgram(
                compileShader(compute_shader_source, GL_COMPUTE_SHADER)
            )

            # Vector size (float32 x4 = 16 bytes)
            self.vector_size = 4

            # Query GPU info
            version = glGetString(GL_VERSION).decode()
            renderer = glGetString(GL_RENDERER).decode()
            print("OpenGL version:", version)
            print("GPU renderer:", renderer)
        except Exception:
            # Do not leave a half-built context or an initialised GLFW behind.
            self.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Release the compute program, the hidden window and GLFW.

        Safe to call more than once. Note that ``glfw.terminate()`` is
        process-wide, so any other GLFW windows are destroyed as well.
        """
        if self.window is not None:
            glfw.make_context_current(self.window)
            if self.shader is not None:
                glDeleteProgram(self.shader)
                self.shader = None
            glfw.destroy_window(self.window)
            self.window = None
        if self._glfw_ready:
            glfw.terminate()
            self._glfw_ready = False

    def process_virtual_lattice(self, virtual_count, safe_batch=200_000_000):
        """
        Stream through the virtual lattice without exceeding VRAM.

        virtual_count: total number of conceptual vectors; values <= 0 mean
            there is nothing to dispatch.
        safe_batch: upper bound on active vectors per pass (fits VRAM). The
            effective batch is additionally clamped to what the device can
            dispatch in one call along x.

        Raises:
            ValueError: if safe_batch is smaller than 1 (the loop could never
                make progress).
        """
        if safe_batch < 1:
            raise ValueError("safe_batch must be a positive integer")

        processed = 0
        if virtual_count > 0:
            group = self._LOCAL_SIZE_X

            # glDispatchCompute rejects group counts above the device limit,
            # so cap the batch at what a single dispatch can cover.
            max_groups = int(glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_COUNT, 0)[0])
            batch_cap = min(safe_batch, virtual_count, max_groups * group)

            # Size the buffer for whole work groups: the padding invocations
            # of the last group are launched too and must stay in bounds.
            padded_vectors = ((batch_cap + group - 1) // group) * group
            buffer_size = padded_vectors * self.vector_size * 4  # bytes

            # One buffer is reused for every pass instead of being
            # re-created per batch.
            ssbo = glGenBuffers(1)
            try:
                glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo)
                glBufferData(GL_SHADER_STORAGE_BUFFER, buffer_size, None, GL_DYNAMIC_DRAW)
                glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, ssbo)

                glUseProgram(self.shader)
                start_index_location = glGetUniformLocation(self.shader, "start_index")

                while processed < virtual_count:
                    batch = min(batch_cap, virtual_count - processed)

                    # The uniform is a 32-bit GLSL uint whose arithmetic wraps
                    # modulo 2**32, so pass the low 32 bits explicitly rather
                    # than handing the binding an out-of-range integer.
                    glUniform1ui(start_index_location, processed & 0xFFFFFFFF)

                    num_groups = (batch + group - 1) // group
                    glDispatchCompute(num_groups, 1, 1)
                    glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT)

                    processed += batch
                    print(f"Processed {processed}/{virtual_count} virtual vectors (this batch: {batch})")
            finally:
                # Free GPU memory even if a GL call raised mid-stream.
                glDeleteBuffers(1, [ssbo])

        print("✅ Completed processing entire virtual lattice")

# -----------------------------
# Main
# -----------------------------
if __name__ == "__main__":
    # Conceptual lattice ceiling: a cube with 16_777_216 (2**24) cells per
    # edge, i.e. 2**72 (~4.7e21) virtual vectors in total.
    lattice_edge = 16_777_216
    total_virtual_vectors = lattice_edge ** 3

    runner = HDGLExecutor()
    runner.process_virtual_lattice(total_virtual_vectors)
